02. PyTorch Intro
PyTorch
PyTorch is an open source Python-based deep learning framework released in October 2016. It was inspired by the Torch framework, which was originally implemented in C with a wrapper in the Lua scripting language. PyTorch wraps the core Torch binaries in Python and provides GPU acceleration for many functions.
Dynamic Neural Networks
Most frameworks, such as TensorFlow, Theano, and Caffe, require a static graph to define a network: the network must first be built, then run, and any change to the network structure requires rebuilding the graph from scratch. By contrast, PyTorch builds the graph as your code runs, so you can change the way your network behaves on the fly.
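As a minimal sketch of what this define-by-run behavior means (the module and layer sizes below are illustrative assumptions, not taken from the lesson), ordinary Python control flow in the forward pass changes the graph that autograd records from one call to the next:

import torch
import torch.nn as nn

class DynamicNet(nn.Module):
    def __init__(self):
        super(DynamicNet, self).__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        # Plain Python control flow decides how many times the layer is
        # applied; the computation graph is rebuilt on every forward call.
        n_repeats = int(x.sum().item()) % 3 + 1
        for _ in range(n_repeats):
            x = torch.relu(self.fc(x))
        return x

net = DynamicNet()
out = net(torch.randn(1, 10))   # the graph is created during this call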
Imperative experience
PyTorch is designed to be intuitive and linear. Code executes line by line, just as ordinary Python does. This helps with debugging, because the stack trace points to exactly where your code was defined.
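A small sketch of this imperative style (the values are arbitrary): each line runs immediately, so you can print intermediate tensors or step through with a standard Python debugger, and an error raises a traceback pointing at the exact line that caused it.

import torch

a = torch.ones(2, 3)        # executed immediately, no session or graph-compile step
b = a * 2                   # b is available right away
print(b.mean())             # inspect intermediate results as you go
# c = a @ b                 # shape mismatch: this line would raise an error whose
                            # traceback points here, not into a compiled graph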
NumPy-like
PyTorch tensors can be used as a replacement for NumPy arrays, with GPU acceleration.
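A brief sketch of this NumPy-like usage (the array values are arbitrary, and the last lines assume a CUDA-capable GPU is available):

import numpy as np
import torch

x = torch.rand(3, 4)              # tensors support a familiar NumPy-style API
y = x.t() @ x                     # transpose and matrix multiply, as with ndarrays
z = torch.from_numpy(np.eye(4))   # convert a NumPy array to a tensor

if torch.cuda.is_available():
    y = y.cuda()                  # the same operations run on the GPU
    y = y * 2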
Quiz 1: PyTorch Features
SOLUTION:
- PyTorch can perform tensor computation, like `numpy`, with strong GPU acceleration.
- PyTorch builds computation graphs dynamically, which provides greater flexibility during development.
- PyTorch is designed to be intuitive and easy to use: each line of code you enter is executed immediately.
Quiz 2: PyTorch Networks
The following Python class defines a neural network using PyTorch's neural network library, torch.nn. The DQN network reads in pixels and produces actions. Although you haven't studied the syntax for this library yet, try your best to answer the quiz.
import torch.nn as nn
import torch.nn.functional as F

class DQN(nn.Module):
    def __init__(self):
        super(DQN, self).__init__()
        # Three convolutional layers, each followed by batch normalization
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)
        # Fully connected head mapping the flattened features to 2 actions
        self.head = nn.Linear(448, 2)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Flatten all feature maps into one vector per sample
        return self.head(x.view(x.size(0), -1))
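As a quick sanity check (not part of the quiz), the head's input size of 448 is consistent with an input screen of 3×40×80 pixels: after three kernel-5, stride-2 convolutions the feature maps are 32×2×7 = 448 values per sample. A hedged usage sketch, assuming that input size:

import torch

net = DQN()
screen = torch.randn(1, 3, 40, 80)   # batch of 1 RGB screen, 40x80 pixels (assumed size)
q_values = net(screen)               # shape (1, 2): one value per action
print(q_values.shape)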